# First import the library
import pyrealsense2 as rs

# Import OpenCV and numpy
import cv2
import numpy as np
from math import tan, pi
from threading import Lock

def get_extrinsics(src, dst):
    """Return the rotation and translation that relate *src* to *dst*.

    Args:
        src: Object exposing ``get_extrinsics_to(dst)`` (a stream profile in
            this file's usage).
        dst: Target passed straight through to ``src.get_extrinsics_to``.

    Returns:
        A ``(R, T)`` tuple: ``R`` is the 9-element rotation reshaped to 3x3
        and transposed, ``T`` is the translation as a NumPy array.
    """
    ext = src.get_extrinsics_to(dst)
    # The flat rotation is transposed after reshaping -- presumably because
    # librealsense stores it column-major; confirm against the SDK docs.
    rotation = np.asarray(ext.rotation).reshape(3, 3).transpose()
    translation = np.array(ext.translation)
    return rotation, translation

def camera_matrix(intrinsics):
    """Build the 3x3 pinhole camera matrix from *intrinsics*.

    Args:
        intrinsics: Object with ``fx``, ``fy`` (focal lengths) and ``ppx``,
            ``ppy`` (principal point) attributes.

    Returns:
        ``[[fx, 0, ppx], [0, fy, ppy], [0, 0, 1]]`` as a NumPy array.
    """
    fx, fy = intrinsics.fx, intrinsics.fy
    cx, cy = intrinsics.ppx, intrinsics.ppy
    rows = [
        [fx, 0, cx],
        [0, fy, cy],
        [0, 0, 1],
    ]
    return np.array(rows)

def fisheye_distortion(intrinsics):
    """Return the first four distortion coefficients of *intrinsics* as an array.

    Only ``intrinsics.coeffs[:4]`` is used; any further coefficients are
    dropped.
    """
    leading_coeffs = intrinsics.coeffs[:4]
    return np.array(leading_coeffs)

# Module-level state shared with the display code: the most recent image and
# its timestamp. Both entries stay None until a frame has been stored.
# viz_lock is presumably meant to serialize access to viz_data -- confirm
# that every reader/writer actually acquires it.
viz_lock = Lock()
viz_data = dict(img=None, timestamp_ms=None)

class RSFrameData:
    """One captured image together with its timestamp."""

    def __init__(self, leftimg, timestamp_ms):
        # Stored as given; no copy or conversion is performed here.
        self._leftimg, self._timestamp_ms = leftimg, timestamp_ms
class RSIMUData:
    """One motion sample: acceleration, velocity and timestamp."""

    def __init__(self, acc, veclocity, timestamp_ms):
        # NOTE: the parameter is spelled "veclocity"; it is kept as-is because
        # keyword callers depend on the name. It is stored under ``_vec``.
        self._timestamp_ms = timestamp_ms
        self._vec = veclocity
        self._acc = acc

class RSDataCollector():
    """Collect undistorted fisheye images and pose samples from a RealSense pipeline.

    Constructing an instance starts the pipeline with ``self.callback`` as the
    frame callback. From then on, every frameset contributes one undistorted
    image from fisheye stream 1 to ``_framesdata`` and every pose frame
    contributes one sample to ``_imusdata``.

    Attributes are documented inline in ``__init__``.
    """

    def __init__(self):
        self._pipe = rs.pipeline()
        self._cfg = rs.config()

        # Everything the callback touches must exist BEFORE the pipeline is
        # started: the callback may fire while __init__ is still running.
        self._frame_mutex = Lock()   # guards _framesdata
        self._imu_mutex = Lock()     # guards _imusdata
        self._framesdata = []        # list of RSFrameData
        self._imusdata = []          # list of RSIMUData
        # Remap tables; None until _init_undistort_maps() has run. The
        # callback drops image frames while these are still None.
        self._mpx = None
        self._mpy = None

        self._pipe.start(self._cfg, self.callback)
        try:
            self._init_undistort_maps()
        except Exception:
            # Do not leave a running pipeline behind a half-built object.
            self._pipe.stop()
            raise

    def _init_undistort_maps(self):
        """Read the fisheye-1 intrinsics and build the OpenCV remap tables."""
        profiles = self._pipe.get_active_profile()
        stream = profiles.get_stream(rs.stream.fisheye, 1).as_video_stream_profile()
        intrinsics = stream.get_intrinsics()
        print("Camera:",  intrinsics)
        # Translate the intrinsics from librealsense into OpenCV
        self._K = camera_matrix(intrinsics)
        self._D = fisheye_distortion(intrinsics)
        (self._width, self._height) = (intrinsics.width, intrinsics.height)
        # Projection matrix of the undistorted image: same focal lengths and
        # principal point as the input camera, with a zero fourth column.
        P = np.array([[intrinsics.fx, 0, intrinsics.ppx, 0],
                      [0, intrinsics.fy, intrinsics.ppy, 0],
                      [0, 0, 1, 0]])
        mpx, mpy = cv2.fisheye.initUndistortRectifyMap(
            self._K, self._D, np.eye(3), P,
            (self._width, self._height), cv2.CV_32FC1)
        self._mpx, self._mpy = mpx, mpy

    def callback(self, frame):
        """Handle one frame delivered by the pipeline.

        Framesets: the image of fisheye stream 1 is undistorted, published to
        the module-level ``viz_data`` and appended to ``_framesdata``.
        Pose frames: acceleration, angular velocity and timestamp are appended
        to ``_imusdata``. Any other frame type is ignored.
        """
        if frame.is_frameset():
            if self._mpx is None or self._mpy is None:
                # Remap tables are not ready yet (frame arrived during
                # __init__); nothing sensible can be done with the image.
                return
            frameset = frame.as_frameset()
            f1 = frameset.get_fisheye_frame(1).as_video_frame()
            left_data = np.asanyarray(f1.get_data())
            left_data = cv2.remap(src=left_data,
                                  map1=self._mpx,
                                  map2=self._mpy,
                                  interpolation=cv2.INTER_LINEAR)
            ts = frameset.get_timestamp()
            # Update both entries under the lock so a reader that also takes
            # viz_lock never sees an image paired with a stale timestamp.
            with viz_lock:
                viz_data["img"] = left_data.copy()
                viz_data["timestamp_ms"] = ts
            with self._frame_mutex:
                self._framesdata.append(RSFrameData(left_data, ts))
        elif frame.is_pose_frame():
            poseset = frame.as_pose_frame()
            posedata = poseset.get_pose_data()
            # get_timestamp must be CALLED; storing the bound method would
            # record a function object instead of the time value.
            sample = RSIMUData(posedata.acceleration,
                               posedata.angular_velocity,
                               poseset.get_timestamp())
            with self._imu_mutex:
                self._imusdata.append(sample)
            

# Start capturing; frames are delivered to the collector in the background.
rsdatas = RSDataCollector()

WINDOW_TITLE = 'Realsense'
cv2.namedWindow(WINDOW_TITLE, cv2.WINDOW_NORMAL)

window_size = 5

# Preview loop: show the most recent image until the user presses 'q'.
try:
    while True:
        # Take a reference to the latest image while holding the lock.
        with viz_lock:
            latest_img = viz_data["img"]
        # viz_data["img"] is None until the first frame has been stored;
        # cv2.imshow cannot display None.
        if latest_img is not None:
            cv2.imshow(WINDOW_TITLE, latest_img)
        # waitKey has to run on every iteration: it services the window's
        # event queue (without it nothing is drawn) and polls the keyboard.
        key = cv2.waitKey(1)
        if key == ord('q'):
            break
finally:
    # Stop the pipeline exactly once, whether the loop ended via 'q' or via
    # an exception. The pipeline object is owned by the collector.
    rsdatas._pipe.stop()
    cv2.destroyAllWindows()

def _vector_components(vec):
    """Return the three components of *vec* as a tuple.

    Accepts objects exposing ``x``/``y``/``z`` attributes as well as plain
    indexable sequences (pyrealsense2 vectors appear to offer only the
    attribute form -- confirm against the SDK reference).
    """
    if hasattr(vec, "x"):
        return (vec.x, vec.y, vec.z)
    return (vec[0], vec[1], vec[2])


def write_realsense_data(data: "RSDataCollector", imu_file: str,
                         image_file: str, image_dir: str) -> None:
    """Write the collected IMU samples and images to disk.

    Args:
        data: Collector whose ``_imusdata`` and ``_framesdata`` lists are
            exported.
        imu_file: Path of the CSV file receiving one line per IMU sample:
            ``timestamp,acc_x,acc_y,acc_z,vec_x,vec_y,vec_z``.
        image_file: Path of the text file receiving one image timestamp per
            line.
        image_dir: Directory in which each image is saved as
            ``<timestamp>.png``.

    Both text files are written without a trailing newline.
    """
    # Function-scope import: the module does not import os at the top.
    import os

    imu_lines = []
    for imu in data._imusdata:
        fields = (imu._timestamp_ms,
                  *_vector_components(imu._acc),
                  *_vector_components(imu._vec))
        # str.join only accepts strings, so convert every value explicitly.
        imu_lines.append(','.join(str(value) for value in fields))
    with open(imu_file, 'w') as f:
        f.write("\n".join(imu_lines))

    timestamps = []
    for frame in data._framesdata:
        stamp = str(frame._timestamp_ms)
        timestamps.append(stamp)
        # os.path.join keeps the result inside image_dir whether or not the
        # caller supplied a trailing path separator.
        cv2.imwrite(os.path.join(image_dir, stamp + ".png"), frame._leftimg)
    with open(image_file, 'w') as f:
        # file.write needs a single string, not a list.
        f.write("\n".join(timestamps))
    